import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import torch.utils.data as Data


def func_stddev(x):
    """Sample (unbiased) standard deviation of ``x`` reduced over dim 1."""
    return torch.std(x, dim=1)


def func_mean(x):
    """Arithmetic mean of ``x`` reduced over dim 1."""
    return torch.mean(x, dim=1)


def func_min(x):
    """Minimum of ``x`` over dim 1 (values only; the argmin indices are dropped)."""
    minima, _ = torch.min(x, dim=1)
    return minima


def func_max(x):
    """Maximum of ``x`` over dim 1 (values only; the argmax indices are dropped)."""
    maxima, _ = torch.max(x, dim=1)
    return maxima


def func_sum(x):
    """Sum of ``x`` reduced over dim 1."""
    return torch.sum(x, dim=1)


def func_zscore(x):
    """Ratio of the mean to the sample standard deviation of ``x`` over dim 1.

    There is no guard against a zero standard deviation, so constant slices
    produce ``inf`` (or ``nan`` when the mean is also zero).
    """
    centre = torch.mean(x, dim=1)
    spread = torch.std(x, dim=1)
    return centre / spread


def func_decaylinear(x):
    """Linearly-decaying weighted average of ``x`` over dim 1.

    Weights ``d, d-1, ..., 1`` (with ``d = x.shape[1]``) are normalised to sum
    to 1, applied to positions ``0 .. d-1`` of dim 1, and summed over that
    dim. The result is reduced over dim 1, like the other ``func_*`` helpers.

    Args:
        x: 3-D tensor; dim 1 is the window being averaged (the weights are
            broadcast with shape ``(1, d, 1)``).

    Returns:
        Tensor of shape ``(x.shape[0], x.shape[2])``.
    """
    d = x.shape[1]
    # Build the weights on x's device/dtype. A CPU LongTensor cannot be
    # multiplied with a CUDA tensor.
    weights = torch.arange(d, 0, -1, dtype=x.dtype, device=x.device)
    weights = weights / weights.sum()
    # NOTE(review): the largest weight goes to index 0 of dim 1. If dim 1 runs
    # oldest -> newest, this favours the oldest step, which is the reverse of
    # the usual decay_linear convention. Confirm the intended ordering.
    return (x * weights.reshape(1, d, 1)).sum(1)


def func_corr(x):
    """Return the pairwise Pearson correlations between the columns of ``x``.

    Adapted from https://www.zhihu.com/question/450669124.

    Args:
        x: 2-D tensor of shape ``(m, n)``; each row is one observation and
            each column one variable.

    Returns:
        1-D tensor of length ``n * (n - 1) // 2``. It holds the strict upper
        triangle of the ``n x n`` correlation matrix in row-major order
        (pairs ``(0, 1), (0, 2), ..., (n - 2, n - 1)``). A column with zero
        variance yields ``nan`` for every pair it belongs to.
    """
    n_obs, n_vars = x.shape
    centered = x - x.mean(dim=0)
    # Population covariance over population std: the 1/m factors cancel, so
    # this equals the usual Pearson correlation coefficient.
    cov = centered.T @ centered / n_obs
    std = centered.pow(2).mean(dim=0).sqrt()
    corr = cov / (std.unsqueeze(1) * std.unsqueeze(0))
    # Keep each unordered pair once and drop the diagonal. The previous code
    # indexed with a global `special_index` that is not defined in this module
    # (NameError); this assumes the strict upper triangle was the intent.
    rows, cols = torch.triu_indices(n_vars, n_vars, offset=1, device=x.device)
    return corr[rows, cols]


class ts_stddev(nn.Module):
    """Rolling sample standard deviation over length-``d`` windows of dim 1.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``. Windows start at ``0 .. T - d - 1``
    with stride 1, and the output is window-major: the ``F`` values of
    window 0, then window 1, and so on. ``T == d`` gives an empty ``(B, 0)``
    result, and ``T < d`` raises ``RuntimeError`` from ``Tensor.unfold``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): T - d + 1 full windows exist, but only the first T - d
        # are used, so the window ending on the last step is dropped. This is
        # kept so the output width (and any layer sized on it) is unchanged.
        # Confirm it is intended.
        n_windows = x.shape[1] - self.d
        # One strided view replaces the Python loop over slices. flatten(1)
        # replaces the deprecated non-inplace Tensor.resize, which warns on
        # every call.
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        return windows.std(-1).flatten(1)


class ts_mean(nn.Module):
    """Rolling mean over length-``d`` windows of dim 1.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``. Windows start at ``0 .. T - d - 1``
    with stride 1, and the output is window-major: the ``F`` values of
    window 0, then window 1, and so on. ``T == d`` gives an empty ``(B, 0)``
    result, and ``T < d`` raises ``RuntimeError`` from ``Tensor.unfold``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): T - d + 1 full windows exist, but only the first T - d
        # are used, so the window ending on the last step is dropped. This is
        # kept so the output width (and any layer sized on it) is unchanged.
        # Confirm it is intended.
        n_windows = x.shape[1] - self.d
        # One strided view replaces the Python loop over slices. flatten(1)
        # replaces the deprecated non-inplace Tensor.resize, which warns on
        # every call.
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        return windows.mean(-1).flatten(1)


class ts_max(nn.Module):
    """Rolling maximum over length-``d`` windows of dim 1.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``. Windows start at ``0 .. T - d - 1``
    with stride 1, and the output is window-major: the ``F`` values of
    window 0, then window 1, and so on. ``T == d`` gives an empty ``(B, 0)``
    result, and ``T < d`` raises ``RuntimeError`` from ``Tensor.unfold``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): T - d + 1 full windows exist, but only the first T - d
        # are used, so the window ending on the last step is dropped. This is
        # kept so the output width (and any layer sized on it) is unchanged.
        # Confirm it is intended.
        n_windows = x.shape[1] - self.d
        # One strided view replaces the Python loop over slices. flatten(1)
        # replaces the deprecated non-inplace Tensor.resize, which warns on
        # every call.
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        return windows.max(-1).values.flatten(1)


class ts_min(nn.Module):
    """Rolling minimum over length-``d`` windows of dim 1.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``. Windows start at ``0 .. T - d - 1``
    with stride 1, and the output is window-major: the ``F`` values of
    window 0, then window 1, and so on. ``T == d`` gives an empty ``(B, 0)``
    result, and ``T < d`` raises ``RuntimeError`` from ``Tensor.unfold``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): T - d + 1 full windows exist, but only the first T - d
        # are used, so the window ending on the last step is dropped. This is
        # kept so the output width (and any layer sized on it) is unchanged.
        # Confirm it is intended.
        n_windows = x.shape[1] - self.d
        # One strided view replaces the Python loop over slices. flatten(1)
        # replaces the deprecated non-inplace Tensor.resize, which warns on
        # every call.
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        return windows.min(-1).values.flatten(1)


class ts_sum(nn.Module):
    """Rolling sum over length-``d`` windows of dim 1.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``. Windows start at ``0 .. T - d - 1``
    with stride 1, and the output is window-major: the ``F`` values of
    window 0, then window 1, and so on. ``T == d`` gives an empty ``(B, 0)``
    result, and ``T < d`` raises ``RuntimeError`` from ``Tensor.unfold``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): T - d + 1 full windows exist, but only the first T - d
        # are used, so the window ending on the last step is dropped. This is
        # kept so the output width (and any layer sized on it) is unchanged.
        # Confirm it is intended.
        n_windows = x.shape[1] - self.d
        # One strided view replaces the Python loop over slices. flatten(1)
        # replaces the deprecated non-inplace Tensor.resize, which warns on
        # every call.
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        return windows.sum(-1).flatten(1)


class ts_zscore(nn.Module):
    """Rolling mean / sample-std ratio over length-``d`` windows of dim 1.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``. Windows start at ``0 .. T - d - 1``
    with stride 1, and the output is window-major: the ``F`` values of
    window 0, then window 1, and so on. Constant windows give ``inf``/``nan``
    because there is no zero-std guard. ``T == d`` gives an empty ``(B, 0)``
    result, and ``T < d`` raises ``RuntimeError`` from ``Tensor.unfold``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): T - d + 1 full windows exist, but only the first T - d
        # are used, so the window ending on the last step is dropped. This is
        # kept so the output width (and any layer sized on it) is unchanged.
        # Confirm it is intended.
        n_windows = x.shape[1] - self.d
        # One strided view replaces the Python loop over slices. flatten(1)
        # replaces the deprecated non-inplace Tensor.resize, which warns on
        # every call.
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        return (windows.mean(-1) / windows.std(-1)).flatten(1)


class ts_decaylinear(nn.Module):
    """Rolling linearly-decaying weighted average over length-``d`` windows.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``, window-major, with windows starting at
    ``0 .. T - d - 1`` (stride 1). Each window is averaged with weights
    ``d, d-1, ..., 1`` normalised to sum to 1.

    NOTE(review): a later ``class ts_decaylinear`` statement in this module
    rebinds the name, so this definition is shadowed at import time.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): only the first T - d of the T - d + 1 windows are used
        # (matching the other rolling ops here). Confirm this is intended.
        n_windows = x.shape[1] - self.d
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        # Weight d goes to the first step of each window. If dim 1 runs
        # oldest -> newest, this favours older data; confirm the ordering.
        weights = torch.arange(self.d, 0, -1, dtype=x.dtype, device=x.device)
        weights = weights / weights.sum()
        # Reduce each window to one value per feature. The previous helper
        # returned the unreduced window, so the 2-D resize failed.
        return (windows * weights).sum(-1).flatten(1)


class ts_return(nn.Module):
    """``d``-step simple return along dim 1: ``x[t] / x[t - d] - 1``.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)`` for ``t = d .. T - 1``, window-major
    (all ``F`` values for ``t = d``, then ``t = d + 1``, and so on). Zero
    denominators produce ``inf``/``nan``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # lag in steps

    def forward(self, x):
        steps = x.shape[1]
        # x[t] / x[t-d] - 1 == (x[t] - x[t-d]) / x[t-d]. The old code
        # subtracted 1 from the latter form as well, returning ratio - 2.
        # `:steps - self.d` (not `:-self.d`) keeps d == 0 well-defined.
        ret = x[:, self.d:] / x[:, :steps - self.d] - 1
        return ret.flatten(1)


class ts_decaylinear(nn.Module):
    """Rolling linearly-decaying weighted average over length-``d`` windows.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time)
    and returns ``(B, (T - d) * F)``, window-major, with windows starting at
    ``0 .. T - d - 1`` (stride 1). Each window is averaged with weights
    ``d, d-1, ..., 1`` normalised to sum to 1.

    NOTE(review): an earlier ``class ts_decaylinear`` statement exists in
    this module; this later definition is the one importers see. Its body
    used to be a verbatim copy of ``ts_return``.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        # NOTE(review): only the first T - d of the T - d + 1 windows are used
        # (matching the other rolling ops here). Confirm this is intended.
        n_windows = x.shape[1] - self.d
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        # Weight d goes to the first step of each window. If dim 1 runs
        # oldest -> newest, this favours older data; confirm the ordering.
        weights = torch.arange(self.d, 0, -1, dtype=x.dtype, device=x.device)
        weights = weights / weights.sum()
        return (windows * weights).sum(-1).flatten(1)


class ts_corr(nn.Module):
    """Rolling pairwise Pearson correlation between the columns of dim 2.

    ``forward`` expects a 3-D tensor ``(B, T, F)`` (dim 1 presumably time).
    For every window ``x[:, i:i + d, :]`` with ``i = 0 .. T - d - 1``, it
    correlates each pair of the ``F`` columns over the ``d`` steps. It then
    keeps the strict upper triangle in row-major order
    (``(0, 1), (0, 2), ..., (F - 2, F - 1)``). The output has shape
    ``(B, (T - d) * F * (F - 1) // 2)``, window-major. Constant columns
    within a window give ``nan``.

    NOTE(review): the previous body only flattened the raw windows and never
    computed a correlation. Any layer sized on the old width
    ``(T - d) * d * F`` must be updated.
    """

    def __init__(self, d):
        super().__init__()
        self.d = d  # window length

    def forward(self, x):
        n_features = x.shape[2]
        # Same window count as the other rolling ops here (T - d).
        n_windows = x.shape[1] - self.d
        windows = x.unfold(1, self.d, 1)[:, :n_windows]  # (B, T - d, F, d)
        centered = windows - windows.mean(-1, keepdim=True)
        # Co-moments over root sums of squares: the 1/d factors cancel.
        cov = centered @ centered.transpose(-1, -2)  # (B, T - d, F, F)
        norm = centered.pow(2).sum(-1).sqrt()  # (B, T - d, F)
        corr = cov / (norm.unsqueeze(-1) * norm.unsqueeze(-2))
        rows, cols = torch.triu_indices(n_features, n_features, offset=1, device=x.device)
        return corr[..., rows, cols].flatten(1)
